import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import random
import csv
from itertools import product
import json
## Objective, gradients, and compressor

def generate_functions(num_nodes, num_dim, zeta):
    """Build the per-node quadratic data (A, B).

    Node ``i`` (0-based) gets the matrix
    ``A[i] = (i + 1) / sqrt(num_nodes) * I`` and the vector
    ``B[i] = zeta / (i + 1) * g_i`` where ``g_i`` is a fresh standard
    normal draw of length ``num_dim`` from NumPy's global generator.

    Returns:
        A with shape (num_nodes, num_dim, num_dim) and
        B with shape (num_nodes, num_dim).
    """
    node_ranks = range(1, num_nodes + 1)

    # Scaled identity matrices, one per node.
    quadratics = []
    for rank in node_ranks:
        unit_scale = 1 / np.sqrt(num_nodes)
        quadratics.append(unit_scale * np.eye(num_dim) * rank)

    # Gaussian offsets, drawn node by node so the random stream is
    # consumed in node order.
    offsets = []
    for rank in node_ranks:
        gaussian = np.random.normal(0, 1, size=num_dim)
        offsets.append(zeta / rank * gaussian)

    return np.array(quadratics), np.array(offsets)

def fval(x, A, B):
    """Return the objective f(x) = 1/(2n) * sum_i ||A_i x - B_i||^2.

    The squared norm is what matches the rest of this file: the gradient
    in ``stoch_gradient`` is ``A_i^T (A_i x - B_i)`` and ``argmin_f``
    solves the normal equations, both of which belong to the
    least-squares objective (not to the mean of un-squared norms).

    Args:
        x: point of evaluation, shape (num_dim,).
        A: per-node matrices, shape (num_nodes, num_dim, num_dim).
        B: per-node vectors, shape (num_nodes, num_dim).

    Returns:
        The scalar objective value at ``x``.
    """
    residuals = np.einsum("ijk,k->ij", A, x) - B  # (num_nodes, num_dim)
    squared_norms = np.sum(residuals ** 2, axis=1)  # ||A_i x - B_i||^2
    return np.mean(squared_norms) / 2.0

def argmin_f(A, B):
    """Return the least-squares minimizer of sum_i ||A_i x - B_i||^2.

    Solves the normal equations
    ``(sum_i A_i^T A_i) x = sum_i A_i^T B_i``.

    Args:
        A: per-node matrices, shape (num_nodes, num_dim, num_dim).
        B: per-node vectors, shape (num_nodes, num_dim).

    Returns:
        The minimizer, shape (num_dim,).

    Raises:
        numpy.linalg.LinAlgError: if ``sum_i A_i^T A_i`` is singular.
    """
    # "ikj,ikl->jl" contracts over the row index, i.e. A_i^T A_i. For
    # symmetric A_i this equals A_i A_i, but only the transposed form is
    # correct for general matrices.
    gram = np.einsum("ikj,ikl->jl", A, A)
    rhs = np.einsum("ijk,ij->k", A, B)  # sum_i A_i^T B_i
    # Solve the linear system directly rather than forming the inverse.
    return np.linalg.solve(gram, rhs)

def argmin_f_noinv(A, B):
    """Return the minimizer without inverting a matrix.

    Uses the closed form that holds when ``A[i]`` is
    ``(i + 1) / sqrt(num_nodes)`` times the identity (the construction
    used by ``generate_functions``): the sum of ``A_i^T B_i`` divided by
    the sum of the squared scaling factors.

    Args:
        A: per-node matrices, shape (num_nodes, num_dim, num_dim).
        B: per-node vectors, shape (num_nodes, num_dim).

    Returns:
        The minimizer, shape (num_dim,).
    """
    node_count, _, _ = A.shape

    # Sum over nodes of ((i + 1) / sqrt(n)) ** 2.
    squared_scale_total = 0
    for idx in range(node_count):
        squared_scale_total += ((idx + 1) / np.sqrt(node_count)) ** 2

    per_node_rhs = np.einsum("ijk,ij->ik", A, B)  # A_i^T B_i for each node
    return per_node_rhs.sum(axis=0) / squared_scale_total

def min_f(A, B):
    """Return ``fval`` evaluated at the minimizer from ``argmin_f_noinv``.

    Args:
        A: per-node matrices, shape (num_nodes, num_dim, num_dim).
        B: per-node vectors, shape (num_nodes, num_dim).
    """
    return fval(argmin_f_noinv(A, B), A, B)
def fdistance(x, A, B):
    """Return the objective gap ``fval(x, A, B) - min_f(A, B)`` (f - f*)."""
    current_value = fval(x, A, B)
    optimal_value = min_f(A, B)
    return current_value - optimal_value

def xdistance(x, A, B):
    """Return ``||x - x*||^2 / len(x)``.

    ``x*`` is the minimizer returned by ``argmin_f_noinv(A, B)``; the
    squared distance is divided by the number of entries of ``x``.

    Args:
        x: current iterate, shape (num_dim,).
        A: per-node matrices, shape (num_nodes, num_dim, num_dim).
        B: per-node vectors, shape (num_nodes, num_dim).
    """
    minimizer = argmin_f_noinv(A, B)
    offset = x - minimizer
    squared_norm = np.linalg.norm(offset) ** 2
    return squared_norm / len(x)

def prob_compressor(x, delta):
    """Randomly keep or drop the whole vector.

    Draws one uniform sample from ``random.random()``; when the sample
    is at most ``delta`` the input ``x`` is returned unchanged (same
    object), otherwise an all-zero array shaped like ``x`` is returned.
    """
    keep = random.random() <= delta
    return x if keep else np.zeros_like(x)

def topK(vec: np.ndarray, delta: float) -> np.ndarray:
    """Keep the largest-magnitude entries of ``vec`` and zero the rest.

    The number of retained entries is ``ceil(len(vec) * delta)``,
    clamped to the range ``[0, len(vec)]``.

    Args:
        vec: 1-D input vector.
        delta: fraction of entries to keep.

    Returns:
        A new array shaped like ``vec`` holding the kept entries at
        their original positions and zeros elsewhere.
    """
    compressed_vec = np.zeros_like(vec)

    # Number of entries to keep, never more than the vector holds.
    k = min(int(np.ceil(len(vec) * delta)), len(vec))

    # Guard k <= 0 explicitly: a slice of the form [-0:] would select
    # every index (keeping the whole vector instead of nothing), and a
    # negative k would slice from the front.
    if k <= 0:
        return compressed_vec

    # Ascending sort by magnitude; the last k indices are the largest.
    top_indices = np.argsort(np.abs(vec))[-k:]
    compressed_vec[top_indices] = vec[top_indices]
    return compressed_vec


def stoch_gradient(x: np.array, A, B, sigma):
    """Return noisy per-node gradients of the quadratic objective.

    For every node ``i`` the exact gradient ``A_i^T (A_i x - B_i)`` is
    perturbed by Gaussian noise scaled by ``sigma / sqrt(num_dim)``,
    drawn from NumPy's global generator.

    Args:
        x: current iterate, shape (num_dim,).
        A: per-node matrices, shape (num_nodes, num_dim, num_dim).
        B: per-node vectors, shape (num_nodes, num_dim).
        sigma: noise level.

    Returns:
        Array of shape (num_nodes, num_dim), one gradient row per node.
    """
    _, dim = B.shape
    residuals = np.einsum("ijk,k->ij", A, x) - B  # A_i x - B_i per node
    exact_grad = np.einsum("ijk,ij->ik", A, residuals)  # A_i^T residual_i
    noise_scale = sigma / np.sqrt(dim)
    perturbation = noise_scale * np.random.normal(0, 1, size=B.shape)
    return exact_grad + perturbation


def gridsearch(name, optimizer, compressor, delta, num_iter, xtol, ABs, num_nodes, num_dim, zeta_range, sigma, gamma_range, eta_range):
    """Grid-search (gamma, eta) for each problem instance and save the results.

    For every index ``i`` of ``zeta_range`` the problem ``ABs[i]`` is
    solved from the all-ones starting point with every combination of
    ``gamma_range`` x ``eta_range``. Selection rule:

    * a run whose final ``xdists`` entry is ``<= xtol`` counts as
      converged; among converged runs the one with the shortest
      ``xdists`` wins;
    * while no run has converged, the run with the smallest final
      ``xdists`` entry wins.

    Args:
        name: prefix used in the output file names.
        optimizer: callable invoked as ``optimizer(x, A, B, compressor,
            delta, gamma, eta, sigma, num_iter=..., xtol=...)`` and
            returning three values, the second being the sequence of
            x-distances.
        compressor, delta, sigma: forwarded to ``optimizer`` unchanged.
        num_iter, xtol: forwarded to ``optimizer`` as keywords; ``xtol``
            is also the convergence threshold used here.
        ABs: indexable collection of ``(A, B)`` pairs, one per entry of
            ``zeta_range``.
        num_nodes: only used in the output file names.
        num_dim: length of the starting point; also used in file names.
        zeta_range: one entry per problem instance; only its length is
            used.
        gamma_range, eta_range: candidate values to try.

    Returns:
        ``(best_gamma, best_eta, best_xdists)``: two float arrays and
        one object array, each with ``len(zeta_range)`` entries.

    Side effects:
        Writes one ``.npy`` and two ``.json`` files under
        ``experiment-data/`` (the directory is created if missing).
    """
    import os  # local import so the block does not depend on file-level imports

    num_problems = len(zeta_range)
    # Explicit float dtype: deriving the dtype from zeta_range would
    # truncate the selected step sizes when zeta_range holds integers.
    best_gamma = np.zeros(num_problems, dtype=float)
    best_eta = np.zeros(num_problems, dtype=float)
    best_xdists = np.empty(num_problems, dtype=object)

    for i in range(num_problems):
        A, B = ABs[i]
        fewest_iters = float('inf')
        best_error = float('inf')
        for gamma in gamma_range:
            for eta in eta_range:
                x = np.ones(num_dim)
                _, xdists, _ = optimizer(x, A, B, compressor, delta, gamma, eta, sigma, num_iter=num_iter, xtol=xtol)
                final_error = xdists[-1]
                if final_error <= xtol:
                    # Pin best_error to xtol so that no later
                    # non-converged run can displace a converged one.
                    best_error = xtol
                    improved = len(xdists) < fewest_iters
                    if improved:
                        fewest_iters = len(xdists)
                else:
                    improved = final_error < best_error
                    if improved:
                        best_error = final_error
                if improved:
                    best_gamma[i] = gamma
                    best_eta[i] = eta
                    best_xdists[i] = xdists

    # Make sure the output directory exists before writing into it.
    os.makedirs('experiment-data', exist_ok=True)

    xdists_file = f'experiment-data/{name}_nodes{num_nodes}_dim{num_dim}_delta{delta}_sigma{sigma}.npy'
    np.save(xdists_file, best_xdists)

    gamma_file = f'experiment-data/{name}_gamma_nodes{num_nodes}_dim{num_dim}_delta{delta}_sigma{sigma}.json'
    with open(gamma_file, 'w') as json_file:
        json.dump({"best_gamma": best_gamma.tolist()}, json_file, indent=4)

    eta_file = f'experiment-data/{name}_eta_nodes{num_nodes}_dim{num_dim}_delta{delta}_sigma{sigma}.json'
    with open(eta_file, 'w') as json_file:
        json.dump({"best_eta": best_eta.tolist()}, json_file, indent=4)

    return best_gamma, best_eta, best_xdists


